# This file is based on the CompGCN author's implementation
# <https://github.com/malllabiisc/CompGCN/blob/master/helper.py>.
# It implements the operation of circular convolution in the ccorr function and an additional in_out_norm function for norm computation.

import torch as th
import dgl

def com_mult(a, b):
	r1, i1 = a[..., 0], a[..., 1]
	r2, i2 = b[..., 0], b[..., 1]
	return th.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim = -1)


def conj(a):
	a[..., 1] = -a[..., 1]
	return a


def ccorr(a, b):
	"""
	Compute circular correlation of two tensors.
	Parameters
	----------
	a: Tensor, 1D or 2D
	b: Tensor, 1D or 2D

	Notes
	-----
	Input a and b should have the same dimensions. And this operation supports broadcasting.

	Returns
	-------
	Tensor, having the same dimension as the input a.
	"""
	return th.irfft(com_mult(conj(th.rfft(a, 1)), th.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],))

#identify in/out edges, compute edge norm for each and store in edata
def in_out_norm(graph):
	src, dst, EID = graph.edges(form='all')
	graph.edata['norm'] = th.ones(EID.shape[0]).to(graph.device)

	in_edges_idx = th.nonzero(graph.edata['in_edges_mask'], as_tuple=False).squeeze()
	out_edges_idx = th.nonzero(graph.edata['out_edges_mask'], as_tuple=False).squeeze()

	for idx in [in_edges_idx, out_edges_idx]:
		u, v = src[idx], dst[idx]
		deg = th.zeros(graph.num_nodes()).to(graph.device)
		n_idx, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
		deg[n_idx]=count.float()
		deg_inv	= deg.pow(-0.5)							# D^{-0.5}
		deg_inv[deg_inv	== float('inf')] = 0
		norm = deg_inv[u] * deg_inv[v]
		graph.edata['norm'][idx] = norm
	graph.edata['norm'] = graph.edata['norm'].unsqueeze(1)

	return graph
	
